# Functions for importing SCTO --------------------------------------------

# Install remotes if not already installed (needed for version-specific installs).
# requireNamespace() only checks availability; require() would attach the
# package and emit a warning when it is missing.
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
library(remotes)

# Read r-requirements.txt ------------------------------------------------------
# Each requirement line has the form "package==version". A version of
# NOT_INSTALLED means "skip this package". Blank lines and lines starting with
# "#" are ignored, and surrounding whitespace is trimmed so that a stray
# trailing space does not force a reinstall on every run.
#
# Look for the file next to this script first. sys.frame(1)$ofile only exists
# when the script is run via source(). Under Rscript, or when pasted into a
# console, sys.frame(1) errors, and dirname(NULL) errors too. So guard the
# lookup and fall back to the project-relative "code/r-requirements.txt".
script_dir <- tryCatch(dirname(sys.frame(1)$ofile), error = function(e) character(0))
requirements_file <- if (length(script_dir) == 1) file.path(script_dir, "r-requirements.txt") else ""
if (!file.exists(requirements_file)) {
  requirements_file <- "code/r-requirements.txt"
}

if (file.exists(requirements_file)) {
  message("=== Checking package versions from r-requirements.txt ===")
  requirements <- trimws(readLines(requirements_file))
  # Remove empty lines and comment lines
  requirements <- requirements[nzchar(requirements) & !startsWith(requirements, "#")]

  message(sprintf("Found %d package requirements to check", length(requirements)))

  for (req in requirements) {
    parts <- trimws(strsplit(req, "==", fixed = TRUE)[[1]])
    pkg_name <- parts[1]

    if (length(parts) == 2) {
      required_version <- parts[2]

      # Skip packages marked as NOT_INSTALLED
      if (required_version == "NOT_INSTALLED") {
        message(sprintf("  [SKIP] %s - marked as NOT_INSTALLED", pkg_name))
        next
      }

      # Check if package is installed and get its version
      pkg_installed <- requireNamespace(pkg_name, quietly = TRUE)

      if (pkg_installed) {
        installed_version <- as.character(packageVersion(pkg_name))
        if (installed_version != required_version) {
          message(sprintf("  [MISMATCH] %s: installed=%s, required=%s", pkg_name, installed_version, required_version))
          message(sprintf("  [INSTALLING] Forcing %s version %s...", pkg_name, required_version))
          # Force install the correct version
          remotes::install_version(pkg_name, version = required_version, upgrade = "never", force = TRUE)
          message(sprintf("  [SUCCESS] %s version %s installed", pkg_name, required_version))
        } else {
          message(sprintf("  [OK] %s version %s", pkg_name, installed_version))
        }
      } else {
        message(sprintf("  [NOT INSTALLED] %s", pkg_name))
        message(sprintf("  [INSTALLING] %s version %s...", pkg_name, required_version))
        remotes::install_version(pkg_name, version = required_version, upgrade = "never")
        message(sprintf("  [SUCCESS] %s version %s installed", pkg_name, required_version))
      }
    }
  }

  message("=== Loading packages ===")
  # Now load all packages (except those marked NOT_INSTALLED)
  loaded_count <- 0
  for (req in requirements) {
    parts <- trimws(strsplit(req, "==", fixed = TRUE)[[1]])
    pkg_name <- parts[1]
    if (length(parts) == 2 && parts[2] != "NOT_INSTALLED") {
      message(sprintf("  Loading %s...", pkg_name))
      suppressPackageStartupMessages(library(pkg_name, character.only = TRUE))
      loaded_count <- loaded_count + 1
    }
  }
  message(sprintf("=== Successfully loaded %d packages ===", loaded_count))
} else {
  warning("r-requirements.txt file not found. Using fallback package list.")

  # Fallback to original list
  list.of.packages <- c("tidyverse", "readxl", "knitr", "modelsummary", "rstatix",
                        "ggsignif", "ggpubr", "stringdist", "lubridate", "fixest",
                        "corrplot", "xtable", "kableExtra", "ggpattern", "tictoc",
                        "janitor", "hdm", "randomizr", "ri2", "ggplot2", "dplyr",
                        "haven", "ggh4x", "wordcloud", "RColorBrewer", "psych",
                        "stringr", "sandwich", "lmtest", "ICC", "remotes")
  new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[, "Package"])]
  if (length(new.packages) > 0) install.packages(new.packages)
  invisible(lapply(list.of.packages, library, character.only = TRUE))
}
options(stringsAsFactors = FALSE, digits = 2)

# MAKE GGSAVE default to device = cairo_pdf.
# Masks ggplot2::ggsave(); all other arguments are forwarded unchanged.
# `device` is a real parameter, so callers can override it (e.g.
# device = "png"). Previously an explicit `device` raised "formal argument
# 'device' matched by multiple actual arguments".
ggsave <- function(..., device = cairo_pdf) {
  ggplot2::ggsave(..., device = device)
}

# Read the "survey" sheet of a SurveyCTO/XLSForm workbook.
#   survey_file_path: path to the form definition (.xlsx).
# Keeps `name`, `type` and the first label column (ignoring "label:data"), and
# drops rows without a `name`. The `type` cell (e.g. "select_one yes_no") is
# split on spaces: the first piece becomes `type`, the second `choices`.
# Further pieces are dropped, and types without a list name get choices = NA.
# Returns (and prints) a tibble with columns name, label, type, choices.
import_survey_questions <- function(survey_file_path) {

  read_excel(survey_file_path, "survey") %>%
    select(name, type, starts_with("label")) %>%
    select(-any_of("label:data")) %>%
    select(1:3) %>%
    set_names(c("name", "type", "label")) %>%
    select(name, label, type) %>%
    # extra = "drop" / fill = "right" give the same result as the defaults,
    # without a warning for every row that has no (or an extra) second piece
    separate(type, into = c("type", "choices"), sep = " ", remove = TRUE,
             extra = "drop", fill = "right") %>%
    filter(!is.na(name)) %>%
    print()

}



# Append a suffix to every value of a factor while keeping the level order.
#   x: a factor.
#   lab_append: string pasted onto each value (via paste0()).
# Returns a factor whose levels are the suffixed versions of the levels that
# occur in `x`, in the same order as levels(x). Unused levels are dropped and
# NA values stay NA. The result is returned visibly (the previous version
# ended in an assignment, so it was invisible).
fct_label_append <- function(x, lab_append) {
  x_out <- paste0(x, lab_append)
  # One representative position for each distinct value of x
  first_seen <- !duplicated(x)
  codes <- as.integer(x[first_seen])
  # Order the suffixed labels by the original level codes. na.last = NA drops
  # the NA code, so NA values do not get a level and stay NA.
  new_levels <- x_out[first_seen][order(codes, na.last = NA)]
  factor(x_out, levels = new_levels)
}

# Extract the variable names used in a model formula.
#   x: a formula (anything whose as.character() yields the formula parts).
# factor(v) wrappers are unwrapped to v, interaction terms "a:b" are split into
# "a" and "b", and the literal "0" (no-intercept term) is dropped.
# Returns a character vector of unique names.
fml_to_vars <- function(x) {
  x %>%
    as.character() %>%
    # Replace each factor(...) call with its argument. The pattern needs a
    # capture group, and the backreference must be written "\\1". The old
    # "\1" is the control character \001, which erased the variable entirely.
    # The lazy ".*?" stops at the first ")" instead of the last one.
    str_replace_all("factor\\((.*?)\\)", "\\1") %>%
    str_split(boundary("word")) %>%
    unlist() %>%
    str_split("\\:") %>%
    unlist() %>%
    .[. != "0"] %>%
    unique()
}


# Read the "choices" sheet of a SurveyCTO/XLSForm workbook.
#   survey_file_path: path to the form definition (.xlsx).
# `value` is normalised to character. Values that parse as numbers are
# re-formatted via as.numeric() (so "01" becomes "1"); other values are kept
# as-is. Rows without a list_name are dropped. `label` becomes "value: label".
# Returns a tibble with columns list_name, value, label.
import_survey_choices <- function(survey_file_path) {

  read_excel(survey_file_path, sheet = "choices") %>%
    mutate(
      # Coerce once and quietly: text codes legitimately fail numeric
      # conversion and previously produced "NAs introduced by coercion" warnings
      value_num_ = suppressWarnings(as.numeric(value)),
      value = case_when(
        !is.na(value_num_) ~ as.character(value_num_),
        !is.na(value) ~ as.character(value)
      )
    ) %>%
    # The helper column value_num_ is dropped by this select()
    select(list_name, value, starts_with("label")) %>%
    select(1:3) %>%
    set_names(c("list_name", "value", "label")) %>%
    filter(!is.na(list_name)) %>%
    mutate(label = as.character(str_glue("{value}: {label}")))

}

# Label as factor one variable.
# Adds a labelled-factor copy of one select_one variable.
#   data: data frame containing column `var`.
#   var: column name (string) to label; printed as a progress indicator.
#   survey_questions: output of import_survey_questions() (name, choices).
#   survey_choices: output of import_survey_choices() (list_name, value, label).
#   .suffix: suffix for the new column name.
# Returns `data` with a new factor column paste0(var, .suffix). Its levels are
# the choice values and its labels are the "value: label" strings.
# Raises an error unless `var` maps to exactly one choice list.
label_factor_one <- function(data, var, survey_questions, survey_choices, .suffix = "_lab") {

  print(var)

  choice_list <- survey_questions %>% filter(name == var) %>% pull(choices) %>% unique()
  # Comparing list_name against a vector whose length is not 1 either errors
  # inside filter() (length 0) or silently recycles and mixes lists (length > 1).
  if (length(choice_list) != 1) {
    stop("Expected exactly one choice list for '", var, "' in survey_questions, found ",
         length(choice_list), call. = FALSE)
  }
  choice_labels <- survey_choices %>% filter(list_name == choice_list) %>%
    select(label, value)

  data[[paste0(var, .suffix)]] <- factor(data[[var]], choice_labels$value, choice_labels$label)
  data

}

# Label as factor multiple variables (chosen using a vector).
# Applies label_factor_one() to each name in `vars`, in order, and returns the
# updated data. An empty `vars` returns `data` unchanged.
label_factor_many <- function(data, vars, survey_questions, survey_choices, .suffix = "_lab") {

  # seq_along() rather than 1:length(): for an empty `vars`, 1:0 iterates over
  # c(1, 0) and fails on vars[[1]]
  for (i in seq_along(vars)) {
    data <- label_factor_one(data, vars[[i]], survey_questions, survey_choices, .suffix = .suffix)
  }
  data
}


# Label all select_one types as factors.
# Finds every question with type "select_one" in survey_questions and adds
# labelled-factor copies via label_factor_many().
#   .suffix: suffix for the new columns. The default used to be the
#     self-referential `.suffix = .suffix`, which errored ("promise already
#     under evaluation") whenever the argument was omitted. It now matches the
#     "_lab" default of label_factor_one()/label_factor_many().
label_factor_all <- function(data, survey_questions, survey_choices, .suffix = "_lab") {

  select_one_vars <- survey_questions %>%
    filter(type == "select_one") %>%
    pull(name)

  label_factor_many(data = data,
                    vars = select_one_vars,
                    survey_questions = survey_questions,
                    survey_choices = survey_choices,
                    .suffix = .suffix)

}

# FOR SELECT_MULTIPLES
# Add a labelled copy of one select_multiple variable.
#   data: data frame containing column `var`, whose cells hold
#     space-separated choice codes.
#   var: column name (string) to label.
#   survey_questions / survey_choices: outputs of import_survey_questions()
#     and import_survey_choices().
#   .suffix: suffix for the new column name.
#   baseline_var: question name used to look up the choice list. Defaults to
#     `var`; pass it when `var` is a derived column name (e.g. inside a repeat
#     group) whose question name differs.
# Returns `data` with a new character column paste0(var, .suffix). For each
# row it holds the matched "value: label" strings joined by "; ". A result
# that is exactly "NA" (e.g. a missing answer) becomes NA. Codes not found in
# the choice list still contribute the text "NA" within a longer string.
# NOTE(review): like label_factor_one, `list_name == choice_list` assumes the
# question maps to exactly one choice list.
label_multi_one <- function(data, var, survey_questions, survey_choices, .suffix = "_lab", baseline_var = NULL) {

  if (is.null(baseline_var)) baseline_var <- var

  choice_list <- survey_questions %>% filter(name == baseline_var) %>% pull(choices)
  choice_labels <- survey_choices %>% filter(list_name == choice_list) %>%
    mutate(value = as.character(value)) %>%
    select(label, value)

  # One row per (original row, selected code); row_number_ remembers the
  # original row so the labels can be collapsed back in the same order
  new_col <- data %>% select(!!var) %>% mutate(row_number_ = row_number()) %>%
    mutate(across(
      all_of(var),
      as.character
    )) %>%
    separate_rows(!!var, sep = " ", convert = FALSE) %>%
    rename(value = !!var) %>%
    mutate(value = as.character(value)) %>%
    left_join(choice_labels, by = c("value")) %>%
    group_by(row_number_) %>%
    summarise(var = paste0(label, collapse = "; ")) %>%
    # paste0(NA) gives the string "NA"; turn that back into a real NA
    mutate(var = ifelse(var == "NA", NA, var)) %>%
    pull(var)

  data[[paste0(var, .suffix)]] <- new_col

  data
}

# Label multiple select_multiple variables (chosen using a vector).
# Applies label_multi_one() to each name in `vars`. `baseline_vars`, if given,
# must be parallel to `vars` and supplies the question name for each one; it
# defaults to `vars`. An empty `vars` returns `data` unchanged.
label_multi_many <- function(data, vars, survey_questions, survey_choices, .suffix = "_lab", baseline_vars = NULL) {

  if (is.null(baseline_vars)) baseline_vars <- vars

  # seq_along() rather than 1:length(): for an empty `vars`, 1:0 iterates over
  # c(1, 0) and fails on vars[[1]]
  for (i in seq_along(vars)) {
    data <- label_multi_one(data, vars[[i]], survey_questions, survey_choices, .suffix = .suffix, baseline_var = baseline_vars[[i]])
  }
  data
}


# Label all select_multiple types.
# Adds labelled copies (via label_multi_many()) of every select_multiple
# question found in `data`, including copies created by repeat groups.
#   data: survey data.
#   survey_questions / survey_choices: outputs of import_survey_questions()
#     and import_survey_choices().
#   .suffix: suffix for the new label columns. The default used to be the
#     self-referential `.suffix = .suffix`, which errored ("promise already
#     under evaluation") whenever the argument was omitted. It now matches the
#     "_lab" default of the other label_* helpers.
# Column names are matched to questions by splitting on "_" and trying the
# first 1-5 pieces as the question name. Uses print_all() and count_prop(),
# which are defined elsewhere in the project.
label_multi_all <- function(data, survey_questions, survey_choices, .suffix = "_lab") {

  select_multi_vars <- survey_questions %>%
    filter(type == "select_multiple") %>%
    select(name, type) %>%
    print_all

  # COUNT HOW MANY (nested repeat groups they are in)

  # Get regex to pick up all select_multis and repeats (regardless of format)
  select_multi_var_list <- select_multi_vars %>%
    pull(name) %>%
    paste0("^", .) %>%
    paste0(., "($|__?\\d)") %>%
    paste0(., collapse = "|") %>%
    print

  # For each matching column, find which prefix (first 1..5 "_"-pieces) is a
  # select_multiple question name; that prefix becomes baseline_var
  possible_names <- data %>% select(matches(select_multi_var_list)) %>% names() %>% enframe() %>%
    separate(value, into = paste0("value_", 1:5), sep = "_", remove = FALSE) %>%
    rename(var_name = value) %>%
    unite("value_1_2", value_1, value_2, remove = FALSE) %>%
    unite("value_1_2_3", value_1, value_2, value_3, remove = FALSE) %>%
    unite("value_1_2_3_4", value_1, value_2, value_3, value_4, remove = FALSE) %>%
    unite("value_1_2_3_4_5", value_1, value_2, value_3, value_4, value_5, remove = FALSE) %>%
    relocate(var_name, value_1) %>%
    select(-c(value_2:value_5)) %>%
    select(-name) %>%

    left_join(select_multi_vars %>% mutate(match = 1), by = c("value_1" = "name"), suffix = c("", "_1")) %>%
    left_join(select_multi_vars %>% mutate(match = 2), by = c("value_1_2" = "name"), suffix = c("", "_2")) %>%
    left_join(select_multi_vars %>% mutate(match = 3), by = c("value_1_2_3" = "name"), suffix = c("", "_3")) %>%
    left_join(select_multi_vars %>% mutate(match = 4), by = c("value_1_2_3_4" = "name"), suffix = c("", "_4")) %>%
    left_join(select_multi_vars %>% mutate(match = 5), by = c("value_1_2_3_4_5" = "name"), suffix = c("", "_5")) %>%
    mutate(
      match = coalesce(match, match_2, match_3, match_4, match_5)
    ) %>%
    select(var_name, match, value_1:value_1_2_3_4_5) %>%
    pivot_longer(value_1:value_1_2_3_4_5) %>%
    group_by(var_name, match) %>%
    mutate(row_i = row_number()) %>%
    filter(row_i == match) %>%
    rename(baseline_var = value) %>%
    select(-row_i) %>%
    print


  # PATTERN WITH REPEATS IS
  # [question]_[repeatindex]
  # and
  # [question]_[choiceval]_[repeatindex]

  # WORK OUT IF REPEAT OR NOT
  repeat_index <- possible_names %>%
    mutate(
      extra_index_1 = str_detect(var_name, paste0(baseline_var, "__?\\d+")),
      extra_index_2 = str_detect(var_name, paste0(baseline_var, "__?\\d+__?\\d+")),
      extra_index = 0 + extra_index_1 + extra_index_2
    ) %>%
    ungroup %>%
    count_prop(extra_index_1, extra_index_2) %>%
    group_by(baseline_var) %>%
    mutate(max_extra_index = max(extra_index),
           min_extra_index = min(extra_index)) %>%
    count_prop(min_extra_index) %>%  # MIN EXTRA INDEX IDENTIFIES NUMBER OF REPEAT GROUPS
    print

  # COLLECT ALL VARS without the choiceval part
  select_multi_vars_no_choiceval <- repeat_index %>%
    filter(extra_index == min_extra_index) %>%
    arrange(var_name)


  label_multi_many(data = data,
                   vars = select_multi_vars_no_choiceval$var_name,
                   survey_questions = survey_questions,
                   survey_choices = survey_choices,
                   .suffix = .suffix,
                   baseline_vars = select_multi_vars_no_choiceval$baseline_var)

}


# Label BOTH select_one and select_multiple questions.
# Runs label_multi_all() and then label_factor_all() with the same suffix.
#   .suffix: suffix for the new label columns. The default used to be the
#     self-referential `.suffix = .suffix`, which errored whenever the
#     argument was omitted. It now defaults to "_lab", like the other helpers.
label_all <- function(data, survey_questions, survey_choices, .suffix = "_lab") {
  data %>%
    label_multi_all(survey_questions, survey_choices, .suffix) %>%
    label_factor_all(survey_questions, survey_choices, .suffix)

}

# Basic histogram of column `x` in `dat`.
#   binwidth, boundary: passed to geom_histogram().
#   fill: optional (tidy-eval) column mapped to fill. Without it the bars are
#     light blue.
#   x_range: optional c(min, max) zoom via coord_cartesian().
# Returns a ggplot object.
hist_basic <- function(dat, x, binwidth = NULL, fill = NULL, boundary = 0, x_range = NULL) {

  # Inspect the *expression* supplied for `fill` instead of evaluating it.
  # `fill` is normally a bare column name, so the old
  # possibly(is.null, FALSE)(fill) only worked by trapping "object not found".
  if (rlang::quo_is_null(rlang::enquo(fill))) {
    plot <- ggplot(dat, aes(x = {{x}})) +
      geom_histogram(boundary = boundary, binwidth = binwidth, colour = "grey", fill = "lightblue")

  } else {
    plot <- ggplot(dat, aes(x = {{x}}, fill = {{fill}})) +
      geom_histogram(boundary = boundary, binwidth = binwidth, colour = "grey")
  }

  if (!is.null(x_range)) plot <- plot + coord_cartesian(xlim = x_range)

  plot

}

# Tidy a fitted model with both 95% and 90% confidence intervals.
# Returns the 95% broom::tidy() table plus these columns:
#   conf.high_90 / conf.low_90: the 90% interval, matched by `term`;
#   n: nobs(x);
#   sig: "***" (p < .01), "**" (p < .05), "*" (p < .1) or "", placed right
#     after p.value.
tidy_90 <- function(x) {

  ci_95 <- broom::tidy(x, conf.int = TRUE, conf.level = 0.95)
  ci_90 <- broom::tidy(x, conf.int = TRUE, conf.level = 0.9)
  ci_90 <- select(ci_90, term, conf.high_90 = conf.high, conf.low_90 = conf.low)

  n_obs <- nobs(x)

  combined <- full_join(ci_95, ci_90, by = c("term"))
  combined <- mutate(combined,
                     n = n_obs,
                     sig = case_when(
                       p.value < 0.01 ~ "***",
                       p.value < 0.05 ~ "**",
                       p.value < 0.1 ~ "*",
                       TRUE ~ ""
                     ))
  relocate(combined, sig, .after = p.value)
}

# Basic coefficient plot for a fitted model.
# Points are estimates, coloured by whether p.value <= 0.05. Thick bars show
# the 90% CI and thin bars the 95% CI (both from tidy_90()). A dashed line
# marks zero, and the axis label reports N = nobs(model).
#   exclude / include: optional regexes. Terms matching `exclude` are dropped,
#     and only terms matching `include` are kept.
#   var_labels: currently unused.
coef_plot <- function(model, exclude = NULL, include = NULL, var_labels = NULL) {
  n <- nobs(model)

  coefs <- tidy_90(model)
  if (!is.null(exclude)) {
    coefs <- filter(coefs, !str_detect(term, exclude))
  }
  if (!is.null(include)) {
    coefs <- filter(coefs, str_detect(term, include))
  }

  grid_theme <- theme(panel.grid.major.y = element_blank(),
                      panel.grid.minor.y = element_blank(),
                      panel.grid.minor.x = element_blank())

  ggplot(coefs, aes(x = term)) +
    geom_hline(yintercept = 0, linetype = "dashed", colour = "skyblue") +
    geom_errorbar(aes(ymin = conf.low_90, ymax = conf.high_90), colour = "#636363", width = 0, size = 1.2) +
    geom_errorbar(aes(ymin = conf.low, ymax = conf.high), colour = "#636363", width = 0, size = 0.7) +
    geom_point(aes(y = estimate, colour = p.value <= 0.05), show.legend = FALSE, size = 3) +
    coord_flip() +
    theme_minimal() +
    grid_theme +
    labs(y = str_glue("Estimate (N = {n})"))
}


# TABLE FOR YES-NO answers.
# Summarise yes(1) / no(0) / refuse(-99) / don't-know(-98) answers for the
# selected columns, one row per question (and per existing group).
#   data: (optionally grouped) survey data.
#   ...: tidy-select of the question columns.
#   questions: table with `name` and `label` columns (e.g. from
#     import_survey_questions()) used to attach question labels.
# Count columns that are zero everywhere are dropped. Repeated consecutive
# identifiers (name, label, group columns) are blanked for readability.
# Returns a knitr::kable() table. Uses sum_na(), defined elsewhere.
kable_yes_no <- function(data, ..., questions) {
  data %>%
    select(...) %>%
    pivot_longer(cols = -any_of(group_vars(.))) %>%
    left_join(questions, by = "name") %>%
    mutate(name = factor(name, levels = unique(name))) %>%
    group_by(name, label, .add = TRUE) %>%
    summarise(
      tibble(
        yes = sum_na(value == 1),
        no  = sum_na(value == 0),
        refuse = sum_na(value == -99),
        dk = sum_na(value == -98),
        prop_yes = yes / sum(!is.na(value)),
        total = sum(!is.na(value))
      )
    ) %>%

    purrr::discard(~all(.x == 0, na.rm = TRUE)) %>%

    ungroup %>%
    arrange(name) %>%
    mutate(
      across(
        -any_of(c("yes", "no", "refuse", "dk", "prop_yes", "total")),
        ~ {
          # Convert to character first: ifelse() on a factor (e.g. `name`)
          # returns the integer level codes instead of the labels
          cur <- as.character(.x)
          prev <- dplyr::lag(cur)
          ifelse(cur == prev & !is.na(prev), "", cur)
        }
      )
    ) %>%
    kable(digits = 2)
}


# Mean of `x` with a 95% confidence interval from an intercept-only lm().
# Returns a one-row tibble with columns y (estimate), ymin and ymax. This
# shape suits stat_summary()/summarise().
# When `x` is empty or all NA, the same one-row shape is returned, filled with
# NA. It used to be a bare NA, which cannot be combined with the data-frame
# results of other groups inside summarise().
mean_ci <- function(x) {
  if (all(is.na(x))) {
    return(tibble::tibble(y = NA_real_, ymin = NA_real_, ymax = NA_real_))
  }

  d <- data.frame(val = x)

  lm(val ~ 1, data = d) %>%
    broom::tidy(conf.int = TRUE) %>%
    select(y = estimate, ymin = conf.low, ymax = conf.high)
}

# Bar chart of the mean of `y` by `x`, with optional fill/facet groupings.
# Means and intervals come from mean_cl_boot() inside summarise().
#   fill, facet: optional (tidy-eval) grouping columns.
#   error_bar: draw the y intervals.
#   percent: format the y axis as percentages.
#   n_label / val_label: print group sizes / rounded means on the bars.
#   flip: coord_flip().
#   width: bar width, also used as the dodge width when `fill` is given.
#   xlim: currently unused.
#   return_data: return the summarised data instead of the plot.
#   label_size: text size for n/value labels.
# Returns a ggplot object (or the summary tibble when return_data = TRUE).
# The summary is also printed.
bar_chart <- function(data, x, y,
                      fill = NULL, facet = NULL, error_bar = TRUE, percent = FALSE,
                      n_label = FALSE, flip = FALSE,
                      width = 0.8,
                      val_label = FALSE,
                      xlim = NULL, return_data = FALSE,
                      label_size = 8
) {

  d <- data %>%
    group_by({{x}}, {{fill}}, {{facet}}) %>%
    summarise(mean_cl_boot({{y}}),
              n = n()) %>%
    mutate(n_lab = str_glue("N={n}")) %>%
    print %>%
    ungroup

  if (return_data) return(d)

  y_lab <- deparse(substitute(y))

  fill_yn <- deparse(substitute(fill)) != "NULL"
  facet_yn <- deparse(substitute(facet)) != "NULL"

  # A single dodge shared by bars, error bars and text labels. The labels used
  # a hard-coded position_dodge(0.8), so they drifted off their bars whenever
  # `width` was not 0.8.
  dodge <- position_dodge(width)

  if (!fill_yn) {
    p <- d %>%
      ggplot(aes(x = {{x}}, y = y, ymin = ymin, ymax = ymax)) +
      geom_col()

    if (error_bar) p <- p + geom_errorbar(width = 0.2)
  } else {
    p <- d %>%
      ggplot(aes(x = {{x}}, y = y, ymin = ymin, ymax = ymax, fill = {{fill}})) +
      geom_col(position = dodge, width = width)

    if (error_bar) p <- p + geom_errorbar(position = dodge, width = 0.2)
  }

  if (facet_yn) {
    facet_fml <- as.formula(paste0("~ ", deparse(substitute(facet))))
    p <- p + facet_wrap(facet_fml)
  }

  p <- p + labs(y = y_lab)


  if (percent) {
    p <- p + scale_y_continuous(labels = scales::percent)
  }

  if (n_label) {
    # Place N labels just above the axis, at 3% of the tallest bar/interval
    max_y <- max(c(d$y, d$ymin, d$ymax), na.rm = TRUE)
    if (fill_yn) {
      p <- p + geom_text(data = d, aes(y = 0.03 * max_y, x = {{x}}, label = n_lab), position = dodge, size = label_size)
    } else {
      p <- p + geom_text(data = d, aes(y = 0.03 * max_y, x = {{x}}, label = n_lab), size = label_size)
    }

  }

  if (val_label) {
    if (fill_yn) {
      p <- p + geom_text(data = d, aes(y = y/2, x = {{x}}, label = round(y, 2)), position = dodge, size = label_size)
    } else {
      p <- p + geom_text(data = d, aes(y = y/2, x = {{x}}, label = round(y, 2)), size = label_size)
    }
  }

  if (flip) {
    p <- p + coord_flip()
  }

  p

}

# Strip every character outside the ASCII range \x01-\x7F (emojis, but also
# accented letters and other non-ASCII text) from a character vector.
remove_emojis <- function(x) {
  stringr::str_remove_all(x, "[^\x01-\x7F]")
}

# Split space-separated select_multiple answers into numeric vectors.
# Returns a list with one numeric vector per element of `x`.
split_multiple <- function(x) {
  pieces <- stringr::str_split(x, " ")
  purrr::map(pieces, as.numeric)
}

# Element-wise membership test: for each position i, is x[[i]] in l[[i]]?
# Returns a logical vector (one value per pair).
in_list <- function(x, l) {
  purrr::map2_lgl(x, l, function(elem, pool) elem %in% pool)
}


# Fill the NAs in `col1` from `col2`, then drop `col2`.
coalesce_drop <- function(df, col1, col2) {
  filled <- mutate(df, {{col1}} := coalesce({{col1}}, {{col2}}))
  select(filled, -{{col2}})
}

# Horizontal bar chart of `y` by `x`, optionally labelled with round(y).
# Returns a ggplot object.
basic_bar <- function(data, y, x, label = TRUE) {
  p <- data %>% ggplot(aes(x = {{x}}, y = {{y}})) + geom_col() + coord_flip()

  # Embrace `y` directly in aes(). Writing "{{y}}" inside a str_glue() string
  # is not tidy evaluation: glue evaluated a variable literally named `y`, so
  # the label was wrong or errored for any other column.
  if (label) p <- p + geom_label(aes(label = round({{y}}, 0)))

  p
}


# Left-pad numbers with zeros to a fixed width (e.g. 5 -> "005").
#   x: numbers (or strings that parse as numbers).
#   width: minimum output width.
# Whole numbers are formatted as integers. formatC()'s default "g" format
# keeps only 4 significant digits, which turned 12345 into "1.234e+04".
# Non-integers keep the previous "g" formatting.
# Raises an error if any value is already wider than `width`; NAs are ignored
# by that check.
leading_zeros <- function(x, width = 3) {
  if (any(nchar(as.character(x)) > width, na.rm = TRUE)) stop("Something's wider than the width")
  x_num <- as.numeric(x)
  is_whole <- all(x_num == round(x_num), na.rm = TRUE)
  formatC(x_num, width = width, flag = "0", format = if (is_whole) "d" else "g")
}


# Per-column missingness summary.
#   sort: order by decreasing share missing.
#   return_count: if TRUE, return the summary tibble (variable, missing,
#     non_missing, prop_missing). Otherwise print it in full and return
#     `data` invisibly, so the function can sit in the middle of a pipe.
count_nas_quiet <- function (data, sort = FALSE, return_count = FALSE) {
  missing_tbl <- tibble::enframe(colSums(is.na(data)))
  present_tbl <- tibble::enframe(colSums(!is.na(data)))

  na_summary <- dplyr::inner_join(missing_tbl, present_tbl, by = "name")
  na_summary <- dplyr::select(na_summary, variable = name, missing = value.x, non_missing = value.y)
  na_summary <- dplyr::mutate(na_summary, prop_missing = missing/(missing + non_missing))

  if (sort) na_summary <- dplyr::arrange(na_summary, -prop_missing)

  if (return_count) return(na_summary)

  print(na_summary, n = Inf)
  invisible(data)
}


# Flag outliers in `y` relative to the overall median.
# Adds median_y and sd_y (computed over the whole, ungrouped data) and
# `outlier`: "upper_outlier" above median + thresh_val * sd,
# "lower_outlier" below median - thresh_val * sd, otherwise NA.
gen_outliers <- function(df, y, thresh_val = 2) {
  flat <- ungroup(df)
  with_stats <- mutate(flat,
                       median_y = median({{y}}, na.rm = TRUE),
                       sd_y = sd({{y}}, na.rm = TRUE))
  mutate(with_stats,
         outlier = case_when(
           {{y}} > median_y + thresh_val * sd_y ~ "upper_outlier",
           {{y}} < median_y - thresh_val * sd_y ~ "lower_outlier"
         ))
}


# Print all column names of `d` as a one-column table (column `name`).
# Uses print_all(), defined elsewhere.
print_names <- function(d) {
  name_tbl <- enframe(names(d))
  name_tbl <- select(name_tbl, name = value)
  print_all(name_tbl)
}


# Reusable y-axis scale that formats values as percentages (add with `+ perc_y`).
perc_y <- ggplot2::scale_y_continuous(labels = scales::percent)


# Lower-triangle correlation plot of the columns of `d`.
# Uses pairwise-complete correlations rounded to 3 dp, ordered by
# hierarchical clustering.
corr_plot <- function(d) {
  cor_mat <- round(cor(d, use = "pairwise.complete.obs"), 3)
  corrplot::corrplot(cor_mat,
                     type = "lower",
                     tl.col = "black",
                     order = "hclust")
}

# Exploratory factor analysis (minres, oblimin rotation) of `d`.
# Returns the loadings as a tibble with a `var` column plus one column per
# factor, sorted by descending MR1 loading.
factor_loadings <- function(d, n_factors) {
  fa_fit <- psych::fa(
    r = d,
    nfactors = n_factors, fm = "minres", rotate = "oblimin"
  )
  loading_mat <- unclass(fa_fit$loadings)
  loading_tbl <- as_tibble(loading_mat, rownames = "var")
  arrange(loading_tbl, desc(MR1))
}

# Faceted histograms (one panel per `name`) of `value_var`, filled by the
# `rev` flag, with TRUE as the first fill level.
hist_index <- function(df, value_var) {
  plot_df <- mutate(df, rev = fct_relevel(factor(rev), "TRUE"))
  ggplot(plot_df, aes(x = {{value_var}}, fill = rev)) +
    geom_histogram(show.legend = FALSE) +
    facet_wrap(~ name)
}


# Draw `n` Bernoulli(p) outcomes as a logical vector (via mc2d::rbern).
# If any draw is NA (e.g. invalid p) and NA_error is TRUE, print n and p and
# stop with an error.
bern_vec <- function(n, p, NA_error = TRUE) {
  draws <- suppressWarnings(as.logical(mc2d::rbern(n, p = p)))
  if (any(is.na(draws)) && NA_error) {
    print(n)
    print(p)
    stop("Bernoulli calc is causing NAs")
  }
  draws
}

# Sample from existing data to get them, and update the group variables as well for FEs.
# Block bootstrap: draws `n` groups with replacement from `data`, which must
# be grouped by exactly one variable. It keeps all rows of each drawn group and
# gives each draw a fresh id "<old id>_<k>", so repeated draws of one group
# become distinct groups. Returns the resampled data grouped by that
# (renamed) group variable. The result is now returned visibly; previously
# the function ended in an assignment.
sample_n_groups_w_replacement <- function(data, n) {

  group_var <- dplyr::group_vars(data)
  # The "{group_var}__new" naming below needs exactly one grouping column
  if (length(group_var) != 1) {
    stop("`data` must be grouped by exactly one variable; it is grouped by ",
         length(group_var), call. = FALSE)
  }
  group_var_new <- paste0(group_var, "__new")

  data_ids <- data %>% dplyr::ungroup() %>% dplyr::select(all_of(group_var)) %>%
    dplyr::distinct() %>% dplyr::slice_sample(n = n, replace = TRUE) %>%
    mutate("{group_var}__new" := paste0(!!sym(group_var), "_", row_number()))

  data_ids %>%
    left_join(data, by = group_var, relationship = "many-to-many") %>%
    ungroup %>%
    mutate("{group_var}" := !!sym(group_var_new)) %>%
    group_by(!!sym(group_var)) %>%
    select(-!!sym(group_var_new))
}


# Randomly flag `n` rows (per group, if `data` is grouped) as sampled.
# Returns `data` with a logical column named `sampled_var`: TRUE for the rows
# drawn by slice_sample(), FALSE otherwise.
# Rows are tracked by position. The previous full_join() on every column
# flagged all copies of a duplicated row whenever one copy was drawn.
randomly_select_in_group <- function(data, n, sampled_var = "sampled_var") {

  row_id <- ".row_id_randomly_select_"

  # cur_group_rows() gives positions in the full data, so ids are unique
  # across groups
  data_id <- data %>%
    mutate("{row_id}" := dplyr::cur_group_rows())

  sampled_ids <- data_id %>%
    slice_sample(n = n) %>%
    pull(!!sym(row_id))

  data_id %>%
    mutate("{sampled_var}" := !!sym(row_id) %in% sampled_ids) %>%
    select(-all_of(row_id))

}



# Reshape scale items long and compute reverse-coding and acquiescence-bias
# adjustments per respondent (KEY).
#   var_regex: regex selecting the item columns (columns ending "_label" are
#     excluded).
#   rev_vec: names of reverse-coded items; these get a "_REV" suffix.
#   max_val, min_val: scale bounds used for reversing.
#   na_vals: optional codes (e.g. -99) to replace with the item median.
# Returns long data with mean_val, resid_val, val_rev, the mean of the
# positively/negatively coded items, acqui_bias and the adjusted val_acqui.
# Uses count_prop(), defined elsewhere.
prep_df <- function(df, var_regex, rev_vec, max_val, min_val = 0, na_vals = NULL) {

  df_1 <- df %>% select(KEY, matches(var_regex)) %>%
    select(-matches("_label$")) %>%
    pivot_longer(matches(var_regex)) %>%
    mutate(rev = name %in% rev_vec,
           name = ifelse(rev, str_glue("{name}_REV"), name))

  if (!is.null(na_vals)) {
    # Compute each item's median from valid answers only. The previous
    # median_na(value) included the very codes being replaced (e.g. -99),
    # dragging the imputed value towards them.
    df_1 <- df_1 %>%
      group_by(name) %>%
      mutate(value = ifelse(value %in% na_vals,
                            stats::median(value[!value %in% na_vals], na.rm = TRUE),
                            value)) %>%
      ungroup
  }

  df_1 %>%
    group_by(KEY) %>%
    mutate(mean_val = mean(value),
           resid_val = value - mean_val,
           val_rev = ifelse(rev, max_val - (value - min_val), value)) %>%
    ungroup %>% count_prop(rev, value, val_rev) %>%
    group_by(KEY) %>%
    mutate(
      mean_val_pos_coded = mean(val_rev[!rev]),
      mean_val_neg_coded = mean(val_rev[rev])
    ) %>%
    mutate(
      acqui_bias = (mean_val_pos_coded - mean_val_neg_coded) / 2,
      val_acqui = ifelse(rev, val_rev + acqui_bias, val_rev - acqui_bias)
    ) %>%
    ungroup
}

# Pivot long item data (KEY, name, value_var) to one column per item and drop
# KEY. Returns a wide table of `value_var`.
make_wide <- function(df, value_var) {
  narrow <- select(df, KEY, name, {{value_var}})
  wide <- pivot_wider(narrow, names_from = name, values_from = {{value_var}})
  select(wide, -KEY)
}


# Parallel analysis (psych::fa.parallel) to help choose the number of factors,
# using minres factoring and factor analysis only (fa = "fa").
n_factors <- function(df) {
  fa_method <- "minres"
  psych::fa.parallel(df, fm = fa_method, fa = "fa")
}

# Joint Wald test on a fixest model; `...` is forwarded to fixest::wald()
# (e.g. the regex of coefficients to test).
f_test <- function(model, ...) fixest::wald(model, ...)

# Mean of `x` with confidence intervals clustered by `cluster`.
# Regresses x on a constant with lfe::felm(), using `cluster` as the cluster
# variable. Returns the estimate and every "conf" column produced by tidy_90()
# (95% and 90% bounds).
mean_cl_cluster <- function(x, cluster) {
  cluster_df <- tibble(
    x = x,
    cluster_ = cluster
  )
  fit <- lfe::felm(x ~ 1 | 0 | 0 | cluster_, data = cluster_df)
  select(tidy_90(fit), estimate, matches("conf"))
}

# FUNCTIONS for inverse covariance weighting ------------------------------

# Like "rownonmiss" in Stata: count the non-NA values per row across the
# columns chosen by `...`, and store the count in column `var`.
row_non_nas <- function(dat, ..., var = "n_non_na") {
  chosen <- select(dat, ...)
  present_per_row <- rowSums(!is.na(chosen))
  mutate(dat, "{var}" := present_per_row)
}

# Like "rowmiss" in Stata: count the NA values per row across the columns
# chosen by `...`, and store the count in column `var`.
row_nas <- function(dat, ..., var = "n_na") {
  chosen <- select(dat, ...)
  missing_per_row <- rowSums(is.na(chosen))
  mutate(dat, "{var}" := missing_per_row)
}

# Row mean across the columns chosen by `...`, stored in column `var`.
# With na.rm = TRUE, warns with the number of rows containing any NA. Rows
# where no value is available (mean NaN) are set to NA.
row_mean <- function(dat, ..., var = "row_mean", na.rm = TRUE) {
  chosen <- select(dat, ...)

  rows_with_na <- sum(rowSums(is.na(chosen)) > 0)
  if (na.rm == TRUE && rows_with_na > 0) {
    warning(rows_with_na, " rows with NAs")
  }

  means <- rowMeans(chosen, na.rm = na.rm)
  means <- ifelse(is.nan(means), NA, means)

  mutate(dat, "{var}" := means)
}

# Function that generates inverse-covariance-weighted indices (from Anderson).
# Inspired by:
# https://github.com/cdsamii/make_index/blob/master/r/index_comparison.R
#
#   dat_raw: data frame of index components plus the weight column. Every
#     column other than `weight_var` is treated as a component.
#   weight_var: (tidy-eval) column of observation weights for cov.wt(). Rows
#     with a missing weight are excluded and receive an NA index.
#   impute_median: if TRUE, remaining missing component values are filled
#     with that row's mean of its non-missing components (despite the
#     argument name). If FALSE, rows are passed to drop_na_count(), which is
#     defined elsewhere.
# Rows where every component is missing are dropped before estimation.
# Prints the weights. Returns a list with:
#   inv_cov_weights: 1 x k matrix of component weights;
#   index_score: numeric vector aligned with the rows of dat_raw.
inverse_cov_index <- function(dat_raw, weight_var, impute_median = TRUE) {

  dat_raw_id <- dat_raw %>%
    mutate(row_id = row_number())

  dat_filt <- dat_raw_id %>%
    filter(!is.na({{weight_var}}))

  # Count missing values among the component (X) columns
  n_nas <- dat_filt %>%
    select(-{{weight_var}}, -row_id) %>%
    as.matrix() %>%
    is.na() %>%
    sum()

  if (n_nas > 0) {

    # Number of component variables
    n_vars <- ncol(dat_filt %>% select(-c({{weight_var}}, row_id)))

    # Remove obs where every component is missing. The old count compared
    # against a non-existent column (dat_filt$n_vars) with the wrong
    # condition, so it was always 0. Its print() call also ignored the extra
    # arguments.
    dat_filt <- dat_filt %>% row_nas(-c({{weight_var}}, row_id))
    n_all_nas <- sum(dat_filt$n_na == n_vars)
    if (n_all_nas > 0) message("Removing ", n_all_nas, " obs due to NAs in all columns")
    dat_filt <- dat_filt %>% filter(n_na != n_vars)

    if (impute_median) {
      # Number of values about to be imputed
      n_nas <- sum(colSums(is.na(dat_filt %>% select(-c({{weight_var}}, row_id, n_na)))))

      # Fill each remaining NA with the mean of that row's other components
      dat_filt <- dat_filt %>% row_mean(-c({{weight_var}}, row_id, n_na), var = "mean") %>%
        mutate(across(-c({{weight_var}}, row_id, n_na, mean),
                      ~ if_else(is.na(.), mean, .)))

      if (n_nas > 0) print(paste("Imputing", n_nas, "missing values with the row mean"))
    } else {
      dat_filt <- drop_na_count(dat_filt)
    }

    dat_filt <- dat_filt %>% select(-n_na)

  }


  # Component matrix (dropping the helper "mean" column added by imputation)
  dat_matrix <- dat_filt %>%
    select(-{{weight_var}}, -row_id, -any_of("mean")) %>%
    as.matrix()

  # Observation weights
  weights <- dat_filt %>%
    select({{weight_var}}) %>%
    deframe()

  # Weighted covariance matrix, inverted once
  v_cov <- cov.wt(dat_matrix, wt = weights)$cov
  v_cov_inv <- solve(v_cov)
  one_vec <- as.matrix(rep(1, ncol(dat_matrix)))  # k x 1 column of ones

  # Inverse-covariance weights: (1' S^-1 1)^-1 1' S^-1, a 1 x k matrix
  inv_cov_weights <- solve(t(one_vec) %*% v_cov_inv %*% one_vec) %*% t(one_vec) %*% v_cov_inv
  # Index = X w', one score per retained row
  index_score <- dat_matrix %*% t(inv_cov_weights)

  print(inv_cov_weights)

  # Add back onto the data
  dat_index_score <- dat_filt %>%
    mutate(index_score = as.vector(index_score))

  # Add back onto raw data (deals with missing weights and filtered NAs)
  dat_index_score_merged <- dat_raw_id %>%
    left_join(dat_index_score, by = "row_id")

  list(inv_cov_weights = inv_cov_weights,
       index_score = dat_index_score_merged$index_score)

}

# Adds an inverse-covariance-weighted index to a data frame.
#   ...: tidy-select specification of the index components.
#   var: name of the new column.
#   weight_var: (tidy-eval) column of observation weights. When omitted or
#     NULL, every observation gets weight 1. Previously the NULL default
#     always errored, because inverse_cov_index() filtered on is.na(NULL).
#   impute_median: passed to inverse_cov_index().
# Rows excluded from estimation (missing weight, all components missing) get
# NA. Returns `dat` with the new column.
add_inverse_cov_index <- function(dat, ..., var = "index_score", weight_var = NULL, impute_median = TRUE) {

  if (rlang::quo_is_null(rlang::enquo(weight_var))) {
    # Equal weights: cov.wt() then gives the ordinary covariance
    dat_raw <- dat %>%
      select(...) %>%
      mutate(.equal_weight_ = 1)
    inv_cov <- inverse_cov_index(dat_raw, weight_var = .equal_weight_, impute_median = impute_median)
  } else {
    dat_raw <- dat %>%
      select(..., {{weight_var}})
    inv_cov <- inverse_cov_index(dat_raw, weight_var = {{weight_var}}, impute_median = impute_median)
  }

  dat %>%
    mutate("{var}" := inv_cov$index_score)

}

# Standardise `x` against a supplied mean and standard deviation:
# returns (x - mean) / sd, element-wise.
z_calc <- function(x, mean, sd) {
  (x - mean) / sd
}

# z-score of `x` against its own mean and standard deviation.
#   na.rm: ignore NAs when computing the mean/sd.
#   weights: optional observation weights. If given, the weighted mean and
#     weighted sd from datawizard are used.
# NOTE(review): `w =` relies on partial matching to datawizard's argument;
# confirm its argument names (weights / remove_na) for the installed version.
z_calc_std <- function(x, na.rm = TRUE, weights = NULL) {
  if (!is.null(weights)) {
    centre <- datawizard::weighted_mean(x, w = weights, na.rm = na.rm)
    spread <- datawizard::weighted_sd(x, w = weights, na.rm = na.rm)
  } else {
    centre <- mean(x, na.rm = na.rm)
    spread <- sd(x, na.rm = na.rm)
  }

  z_calc(x = x, mean = centre, sd = spread)
}

# Calculates z-scores of `x` using the control group's mean and standard
# deviation (NAs in `x_control` are ignored).
z_calc_control <- function(x, x_control) {
  centre <- mean(x_control, na.rm = TRUE)
  spread <- sd(x_control, na.rm = TRUE)
  (x - centre) / spread
}

# Leave-one-out z-score: each element v[i] is standardised using the mean and
# standard deviation of all *other* elements, v[-i], with NAs ignored.
# Example: z_calc_excl(c(1, 2, 3)) gives c(-1.5, 0, 1.5) / sqrt(0.5).
# Returns a numeric vector the same length as `v`.
z_calc_excl <- function(v) {
  # Plain vapply() loops replace the rowwise tibble and list flattening,
  # which were slow and depended on tibble/purrr and z_calc()
  loo_mean <- vapply(seq_along(v), function(i) mean(v[-i], na.rm = TRUE), numeric(1))
  loo_sd <- vapply(seq_along(v), function(i) sd(v[-i], na.rm = TRUE), numeric(1))
  (v - loo_mean) / loo_sd
}

# Expand a list-column of selected choices into one logical column per choice.
#   x_split: (tidy-eval) column whose elements are vectors of selected values,
#     e.g. the output of split_multiple().
# Adds a column "<x_split>_<value>" for every distinct value (sorted; "-" in
# the value becomes "_") that is TRUE when that row's element contains the
# value. The input's grouping is preserved. The previous version left the
# result rowwise, which silently changed later summarise()/mutate() calls.
expand_select_multiple <- function(df, x_split) {
  vals <- df %>% pull({{x_split}}) %>% unlist() %>% unique() %>% sort()

  df_out <- df

  for (i in vals) {
    i_sanitised <- i %>% as.character() %>% str_replace_all("-", "_")
    df_out <- df_out %>%
      mutate(
        "{{x_split}}_{i_sanitised}" := purrr::map_lgl({{x_split}}, function(choices) i %in% choices)
      )
  }

  df_out
}

# From a dep-means expression string (e.g. "group_label == 'no_discuss'"),
# find the variables it refers to, so glance_dep_means() can check that they
# exist in the data.
# Returns the unique variable names in order of appearance, excluding names
# defined in base R (e.g. T, F, pi). Function names are not returned: the old
# all.names() version reported calls such as str_detect() as missing
# variables.
expression_to_vars <- function(expr) {
  setdiff(all.vars(str2lang(expr)), ls("package:base"))
}


# Extra goodness-of-fit rows for a fitted model (used as glance_custom).
#   x: model object exposing $fml, $data and optionally $lasso. The data must
#     contain ind_id and group_id columns.
#   dep_means: named list of filter expressions (strings). For each one, the
#     mean of the outcome among matching rows is added as a column named
#     after the list element. An expression that refers to variables absent
#     from the data produces a warning and an NA.
# Returns (and prints) a one-row tibble with N participants, N groups, LASSO
# controls ("X" or NA) and the requested dependent-variable means. Uses
# mean_na(), defined elsewhere.
glance_dep_means <- function(x, dep_means = list("Mean: No discussion (private)" = "group_label == 'no_discuss'"), ...) {

  y <- x$fml %>% all.vars() %>% .[[1]]

  # Get the data
  d <- x$data %>% ungroup()

  # How many people?
  n_ind <- d %>% summarise(n = n_distinct(ind_id)) %>% pull(n) %>% as.integer()
  n_groups <- d %>% summarise(n = n_distinct(group_id)) %>% pull(n) %>% as.integer()

  out <- tibble(
    `N participants` = n_ind,
    `N groups` = n_groups
  )

  # Add on lasso controls. isTRUE(): models that never set $lasso give NULL,
  # and if (NULL) is an error.
  out["LASSO controls"] <- if (isTRUE(x$lasso)) "X" else NA

  print(dep_means)
  if (!is.null(dep_means)) {
    for (i in seq_along(dep_means)) {

      dep_mean_name <- names(dep_means)[[i]]
      dep_mean_exp <- dep_means[[i]]

      # Check that all the variables in the expression exist in the data
      vars <- expression_to_vars(dep_mean_exp)
      missing_vars <- vars[!vars %in% names(d)]
      if (length(missing_vars) > 0) {
        # Collapse the names so the warning lists them once, comma-separated
        warning(str_glue("The following variables in the expression '{dep_mean_exp}' are not in the data: {paste(missing_vars, collapse = ', ')}"))
        control_mean <- NA
      } else {
        control_data <- d %>% filter(!!rlang::parse_expr(dep_mean_exp))

        if (nrow(control_data) > 0) {
          control_mean <- mean_na(control_data[[y]])
        } else {
          control_mean <- NA
        }
      }

      out[dep_mean_name] <- control_mean
    }
  }

  print(out)

  out

}


# Turn a vector of (repeated) column labels into a kableExtra header spec.
# Consecutive runs of equal values collapse into one entry: the name is the
# label and the value is the run length,
# e.g. c("a", "a", "b", "a") -> c(a = 2, b = 1, a = 1).
# Consecutive NAs form a single run. Empty input gives an empty named vector.
# Uses only base R. The previous version needed dplyr::lag to mask stats::lag,
# and NA labels misaligned the names against the lengths.
vec_to_custom_header <- function(x) {
  x <- as.character(x)
  n <- length(x)
  if (n == 0) return(stats::setNames(numeric(0), character(0)))

  nxt <- x[-1]
  cur <- x[-n]
  # TRUE where an element equals its predecessor, with NA treated as equal to NA
  same_as_prev <- c(FALSE, (nxt == cur & !is.na(nxt) & !is.na(cur)) | (is.na(nxt) & is.na(cur)))

  starts <- which(!same_as_prev)
  len <- diff(c(starts, n + 1))  # length of each group
  names(len) <- x[starts]
  len
}

# Convert a named vector into a one-row wide table (one column per name),
# e.g. for modelsummary's add_rows.
vec_to_add_rows <- function(v) {
  long <- enframe(v)
  pivot_wider(long, names_from = name, values_from = value)
}


# Write a proportion as a LaTeX percentage to `file` (e.g. 0.123 -> "12\%%").
# The value is multiplied by 100 and rounded to `digits` decimals (padded
# with format() when digits > 0). The trailing "%" comments out LaTeX's
# following newline. `x` is printed first.
write_percentage <- function(x, file, digits = 0) {
  print(x)
  pct <- round(x * 100, digits = digits)

  pct_text <- if (digits > 0) {
    format(pct, digits = digits, nsmall = digits)
  } else {
    as.character(pct)
  }

  writeLines(paste0(as.character(pct_text), "\\%%"), file)
}

# Write a single statistic to `file` for inclusion in LaTeX, followed by "%"
# (which suppresses the trailing newline/space).
#   digits: decimals to round to (padded with format() when digits > 0).
#   p_value: treat `x` as a p-value. Values below 0.001 are written as
#     "$<$0.001", and values that would round to zero at `digits` are written
#     as "$<$" followed by 10^-digits (e.g. "$<$0.01"), instead of a
#     misleading "0.00".
# `x` is printed first.
write_stat <- function(x, file, digits = 2, p_value = FALSE) {
  print(x)
  # && rather than &: this is a scalar condition inside if()
  if (p_value && x < 0.001) {
    out <- "$<$0.001"
  } else if (p_value && round(x, digits = digits) == 0) {
    out <- paste0("$<$", format(10^-digits, scientific = FALSE))
  } else {
    out <- round(x, digits = digits)

    if (digits > 0) {
      out <- format(out, digits = digits, nsmall = digits)
    } else {
      out <- as.character(out)
    }
  }

  writeLines(paste0(out, "%"), file)
}

# Write a p-value to `file` using write_stat()'s p-value formatting.
write_p_val <- function(x, file, digits = 2) {
  write_stat(x = x, file = file, digits = digits, p_value = TRUE)
}

# Write "min, max" of `x` (from min_na()/max_na(), defined elsewhere) to
# `file`, each formatted with format(digits, nsmall = digits), followed by "%".
write_range <- function(x, file, digits = 2) {
  lower <- format(min_na(x), digits = digits, nsmall = digits)
  upper <- format(max_na(x), digits = digits, nsmall = digits)

  writeLines(paste0(lower, ", ", upper, "%"), file)
}

# Export a named list of models as a regression table via modelsummary().
#
# models:            named list of fitted models. The names become a custom
#                    header row; the columns themselves are labelled
#                    "(1)", "(2)", ...
# file:              if NULL, the modelsummary() table is returned. Otherwise
#                    a LaTeX table is written to this path with the
#                    \begin{table}/\end{table} wrapper, the caption and the
#                    default stars footnote removed.
# column_widths:     optional list(columns, width) for kableExtra::column_spec().
# stars, gof_omit:   forwarded to modelsummary().
# dep_means:         passed to glance_dep_means() through the
#                    glance_custom.fixest() function defined below.
# additional_header: optional extra header row (kableExtra::add_header_above()).
# coef_rename:       forwarded to modelsummary(); cleared when coef_reorder
#                    is used (coef_map is used instead).
# coef_reorder:      terms to list first; all other terms keep their default
#                    order. Errors if a term is not in the models.
# add_rows:          extra rows for modelsummary(add_rows = ).
# controls_row:      if TRUE, add a "Controls" row with an "X" per model,
#                    before any user-supplied add_rows.
# ...:               forwarded to modelsummary(). A `stat_vec` entry is also
#                    read from here:
#                      NULL or "wide" -> "({std.error}) [{p.value}]"
#                      "long"         -> SE and p-value on separate rows
#                      anything else  -> used as `statistic` as-is.
tex_export <- function(models, file = NULL,
                       column_widths = NULL,
                       stars = c("*" = .1, "**" = 0.05, "***" = 0.01),
                       gof_omit = "AIC|IC|RMSE|Std|R2 Adj",
                       dep_means = NULL,
                       additional_header = NULL,
                       coef_rename = NULL,
                       coef_reorder = NULL,
                       add_rows = NULL,
                       controls_row = FALSE,
                       ...) {

  # Prepare headers from the original model names (replaced just below)
  model_names <- names(models)
  header_vec <- vec_to_custom_header(model_names)

  # Make the model names into column indices (1), (2) etc.
  names(models) <- str_glue("({seq_along(models)})")

  if (controls_row) {
    # Add a row for the controls
    controls <- list_to_add_rows(list("Controls" = rep("X", length(models))))
    if (is.null(add_rows)) {
      add_rows <- controls
    } else {
      add_rows <- combine_add_rows(controls, add_rows)
    }
  }

  # (Re)define glance_custom.fixest() so that it closes over this call's
  # `dep_means`. NOTE: `<<-` makes this a global side effect that persists
  # after tex_export() returns.
  glance_custom.fixest <<- function(x, ...) {
    glance_dep_means(x, dep_means = dep_means, ...)
  }

  # Resolve the `statistic` argument. The length check keeps a custom
  # multi-element stat_vec from producing a length > 1 `if` condition
  # (an error since R 4.2).
  stat_vec_arg <- list(...)[["stat_vec"]]
  if (is.null(stat_vec_arg) ||
      (length(stat_vec_arg) == 1L && stat_vec_arg == "wide")) {
    stat_vec <- "({std.error}) [{p.value}]"
  } else if (length(stat_vec_arg) == 1L && stat_vec_arg == "long") {
    stat_vec <- c("({std.error})", "[{p.value}]")
  } else {
    stat_vec <- stat_vec_arg
  }

  # If you need to reorder, then create coef_map:
  if (!is.null(coef_reorder)) {
    # (a) Get default term order by running with coef_rename = NULL
    default_order <- modelsummary(
      models,
      stars = stars,
      gof_omit = gof_omit,
      statistic = stat_vec,
      escape = FALSE,
      booktabs = TRUE,
      output = "data.frame",
      add_rows = NULL,
      coef_rename = NULL,
      ...
    ) %>%
      filter(part == "estimates") %>%
      pull(term) %>%
      unique()

    if (any(!coef_reorder %in% default_order)) {
      print(str_glue("coef_reorder: {coef_reorder}"))
      print(str_glue("default_order: {default_order}"))
      stop("coef_reorder terms not in the specification")
    }

    # (b) Requested terms first, the rest in their default order
    new_order <- c(coef_reorder, setdiff(default_order, coef_reorder))

    # (c) coef_map: names are the original terms, values the labels from
    # coef_label(); coef_rename is cleared when coef_map is used
    coef_map <- coef_label(new_order)
    names(coef_map) <- new_order
    coef_rename <- NULL
  } else {
    coef_map <- NULL
  }

  if (is.null(file)) {
    modelsummary(
      models,
      stars = stars,
      gof_omit = gof_omit,
      statistic = stat_vec,
      escape = FALSE,
      booktabs = TRUE,
      coef_rename = coef_rename,
      coef_map = coef_map,
      add_rows = add_rows,
      ...
    )
  } else {

    out_w_table <- modelsummary(
      models,
      output = "latex",
      stars = stars,
      gof_omit = gof_omit,
      statistic = stat_vec,
      escape = FALSE,
      booktabs = TRUE,
      coef_rename = coef_rename,
      coef_map = coef_map,
      add_rows = add_rows,
      ...
    )

    out_w_table <- out_w_table %>% kableExtra::add_header_above(c(" " = 1, header_vec))

    if (!is.null(additional_header)) {
      out_w_table <- out_w_table %>% kableExtra::add_header_above(additional_header)
    }

    if (!is.null(column_widths)) {
      out_w_table <- out_w_table %>% kableExtra::column_spec(column = column_widths[[1]], width = column_widths[[2]])
    }

    # The stars footnote line for the default `stars`, i.e.
    #   \multicolumn{<n>}{l}{\rule{0pt}{1em}* p $<$ 0.1, ** p $<$ 0.05, *** p $<$ 0.01}\\
    # for any column count <n>. (This replaces one fixed-string replacement
    # per column count, which only covered 2 to 11 columns.)
    stars_footnote_regex <- paste0(
      "\\\\multicolumn\\{\\d+\\}\\{l\\}",
      "\\{\\\\rule\\{0pt\\}\\{1em\\}",
      "\\* p \\$<\\$ 0\\.1, \\*\\* p \\$<\\$ 0\\.05, \\*\\*\\* p \\$<\\$ 0\\.01\\}",
      "\\\\\\\\"
    )

    out_w_table %>%
      kable_remove_table() %>%
      str_replace_all(stars_footnote_regex, "") %>%
      str_replace_all("\\\\caption\\{.*\\}", "") %>%
      writeLines(file)
  }
}


# Remove the "\begin{table}" line break and "\end{table}" markers from
# LaTeX table output, leaving the inner tabular content.
kable_remove_table <- function(x) {
  begin_marker <- "\\\\begin\\{table\\}\n"
  end_marker <- "\\\\end\\{table\\}"
  without_begin <- str_remove_all(x, begin_marker)
  str_remove_all(without_begin, end_marker)
}

# Join terms into a formula right-hand side: c("a", "b") -> "a + b".
fml_sum <- function(x) paste(x, collapse = " + ")

# Convert a two-sided formula to a single string, e.g. y ~ a + b -> "y ~ a + b".
# (A one-sided formula has no third element and errors.)
fml_to_char <- function(x) {
  pieces <- as.character(x)
  lhs <- pieces[[2]]
  rhs <- pieces[[3]]
  paste(lhs, "~", rhs)
}

# Combine patterns into one alternation regex: ("a", c("b", "c")) -> "a|b|c".
# NULL arguments are dropped by c().
vars_to_regex <- function(...) {
  patterns <- c(...)
  paste(patterns, collapse = "|")
}

# Format a balance table as LaTeX with knitr::kable() and kableExtra.
#
# x: data frame laid out as: one label column, then the columns whose names
#    contain "Mean", then p-value columns (the two header rows are built from
#    that order — assumed layout). "Mean " and "p-value for test of:" are
#    stripped from the column names. The second-to-last and last rows are
#    set in bold with a rule above them.
#
# Side effect: sets options(knitr.kable.NA = "") globally and does not
# restore it (kept as-is since other code may rely on it).
kable_balance <- function(x) {
  options(knitr.kable.NA = "")

  which_means <- which(str_detect(names(x), "Mean"))
  n_means <- length(which_means)
  n_p_vals <- ncol(x) - n_means - 1

  vec_header <- vec_to_custom_header(
    c(" ",
      rep("Means", n_means),
      rep("p-values", n_p_vals)
    )
  )

  # Numbers (1), (2), ... above the mean columns. sprintf() over seq_len()
  # yields nothing when there are no mean columns (1:0 gave "(1)", "(0)").
  vec_header_2 <- vec_to_custom_header(
    c(" ",
      sprintf("(%d)", seq_len(n_means)),
      rep(" ", n_p_vals))
  )

  x %>%
    rename_with(~ .x %>% str_replace_all(
      "Mean ", ""
    ) %>% str_replace_all(
      "p-value for test of:", ""
    )) %>%
    kable(format = "latex", digits = 2, linesep = "", booktabs = TRUE) %>%
    kable_styling() %>%
    add_header_above(vec_header_2) %>%
    add_header_above(vec_header) %>%
    row_spec(nrow(x) - 2, hline_after = TRUE) %>%
    row_spec((nrow(x) - 1):nrow(x), bold = TRUE) %>%
    kable_remove_table()
}


# TRUE when `x` inherits from class "formula" (one- or two-sided).
is.formula <- function(x) {
  inherits(x, what = "formula")
}


# feols() wrapper that keeps the estimation data on the returned model.
#
# fml:           model formula passed to feols().
# data:          data frame to estimate on.
# cluster:       passed to feols() (and to ri_custom_v2() when ri = TRUE).
# ri:            if TRUE, run ri_custom_v2() and store the result in
#                model$ri_out.
# lasso:         if TRUE, append controls chosen by get_lasso_controls() to
#                the formula; requires lasso_options.
# lasso_options: list with elements "t" (one lasso selection per element),
#                "potential_controls", "interact" and "group_control".
# quiet:         if TRUE, skip the missing-value report and lasso printout,
#                and silence feols() messages/warnings. Also passed to
#                get_lasso_controls().
# ...:           forwarded to feols() and ri_custom_v2(). A `fixef` entry is
#                also included in the missing-value check.
#
# Returns the feols model with extra elements data, cluster and lasso, plus
# lasso_controls (when lasso = TRUE) and ri_out (when ri = TRUE).
feols_custom <- function(fml, data, cluster = NULL, ri = FALSE, lasso = FALSE, lasso_options = NULL, quiet = FALSE, ...) {
  # Variables the model uses, for the missing-value report below.
  # (`cluster` is a formal argument, so it never appears in `...`.)
  vars <- c(fml_to_vars(fml), cluster)

  fixef_arg <- list(...)[["fixef"]]
  if (!is.null(fixef_arg)) {
    vars <- c(vars, fixef_arg)
  }

  # Flatten to plain variable names: drop "~" (one-sided formulas) and the
  # intercept "1", and split interaction terms "a:b" into "a", "b".
  vars <- vars %>%
    map(as.character) %>%
    map(~ .x[.x != "~"]) %>%
    flatten_chr() %>%
    .[. != "1"] %>%
    str_split(":") %>%
    flatten_chr()

  nas <- data %>%
    select(any_of(vars)) %>%
    count_nas_quiet(return_count = TRUE) %>%
    filter(missing > 0)

  # Scalar condition, so `&&` rather than the vectorised `&`
  if (nrow(nas) > 0 && !quiet) {
    warning(str_glue("There are {nrow(nas)} variables with missing values in the data."))
    print(nas)
  }

  if (lasso) {

    if (is.null(lasso_options)) {
      stop("If lasso is TRUE, you must provide lasso_options")
    }

    # One selection per element of lasso_options$t; keep the union
    lasso_controls <- lasso_options[["t"]] %>%
      map(
        ~ get_lasso_controls(
          potential_controls = lasso_options[["potential_controls"]],
          y = as.character(fml)[[2]],
          t = .x,
          df = data,
          interact = lasso_options[["interact"]],
          group_control = lasso_options[["group_control"]],
          quiet = quiet
        )
      ) %>%
      flatten_chr() %>%
      unique()

    # Update data to include possible controls
    if (!is.null(lasso_options[["interact"]])) {
      interact <- lasso_options[["interact"]]

      # <control>__<interact> = control * interact
      data <- data %>%
        mutate(
          across(any_of(lasso_options[["potential_controls"]]),
                 ~ .x * as.numeric(!!sym(interact)),
                 .names = "{.col}__{interact}")
        )
    }

    if (!is.null(lasso_options[["group_control"]])) {
      # <control>__group_control = (group sum - own value) / (group size - 1).
      # NOTE(review): groups by a hard-coded `group_id` column; the
      # group_control option only acts as a switch here — confirm intended.
      data <- data %>%
        group_by(group_id) %>%
        mutate(
          across(any_of(lasso_options[["potential_controls"]]),
                 list(group_control = ~ (sum_na(.x) - .x) / (n() - 1)),
                 .names = "{.col}__group_control")
        ) %>%
        ungroup()
    }

    if (!quiet) {
      print(str_glue("lasso_controls: {paste0(lasso_controls, collapse = ', ')}"))
    }

    if (length(lasso_controls) == 0) {
      warning("No lasso controls found. Using original formula.")
    } else {
      fml <- as.formula(paste0(fml_to_char(fml), " + ", fml_sum(lasso_controls)))
    }

  }

  if (!quiet) {
    model <- feols(
      fml = fml,
      data = data,
      cluster = cluster,
      ...
    )
  } else {
    suppressMessages(suppressWarnings(
      model <- feols(
        fml = fml,
        data = data,
        cluster = cluster,
        ...
      )
    ))
  }

  if (ri) {
    model$ri_out <- ri_custom_v2(
      df = data,
      model = model,
      cluster = cluster,
      ...
    )
  }

  model$data <- data
  model$cluster <- cluster

  if (lasso) {
    model$lasso_controls <- lasso_controls
    model$lasso <- TRUE
  } else {
    model$lasso_controls <- NULL
    model$lasso <- FALSE
  }

  return(model)
}

# tidy_custom() method for fixest models: broom::tidy() output, with p.value
# replaced by the randomization-inference p-value (column `p` of x$ri_out)
# for every term that has one. An extra `ri_p` column is kept when RI is used.
tidy_custom.fixest <- function(x, ...) {
  tidied <- broom::tidy(x)

  if (is.null(x$ri_out)) {
    return(tidied)
  }

  ri_p_vals <- select(x$ri_out, term, ri_p = p)
  joined <- left_join(tidied, ri_p_vals, by = c("term"))
  mutate(joined, p.value = coalesce(ri_p, p.value))
}








# Superseded by get_comparison_p_vals().
#
# Treat `var` in model$data as a factor. For each level other than the first,
# re-estimate the model with that level as the reference (fct_relevel). Then
# collect the p-values of the terms matching `var`.
#
# Returns a tibble with columns base, term, p.value.
get_comparison_p_vals_OLD <- function(model, var) {
  d <- model$data
  d[[var]] <- factor(d[[var]])

  # Every level except the current reference level
  alt_bases <- levels(d[[var]])[-1]

  p_vals_for_base <- function(new_base) {
    d_alt <- d
    d_alt[[var]] <- fct_relevel(d[[var]], new_base)

    feols_custom(
      fml = model$fml,
      data = d_alt,
      fixef = model$fixef_vars,
      cluster = model$cluster
    ) %>%
      tidy_90() %>%
      filter(str_detect(term, var)) %>%
      select(term, p.value) %>%
      mutate(base = new_base) %>%
      relocate(base)
  }

  bind_rows(tibble(), lapply(alt_bases, p_vals_for_base))
}


# Pairwise comparison p-values between the categories of a set of mutually
# exclusive dummies. The model is re-estimated with each included dummy in
# turn as the omitted (base) category.
#
# model:         model from feols_custom(); uses model$fml, model$data,
#                model$fixef_vars and model$cluster.
# dummy_vars:    all dummies, including exactly one omitted from the formula
#                (the current base). The included dummies must appear in the
#                formula in this order, joined by " + ".
# ri:            if TRUE, p-values come from randomization inference
#                (feols_custom(ri = TRUE)$ri_out); `control_group` (with the
#                old and new base swapped) and `...` are passed on for that.
# control_group: control-group specification for RI.
# keep_vars:     extra term patterns to keep (non-RI branch only).
#
# Returns a tibble with columns base, term, p.value.
get_comparison_p_vals <- function(model, dummy_vars, ri = FALSE, control_group = NULL, keep_vars = NULL, ...) {

  fml <- model$fml

  # Fixed effects of the original model. fixest stores them in `fixef_vars`.
  # `model$fixef` has no exact match, and `$` partial matching is ambiguous
  # (fixef_vars, fixef_id, ...), so it returned NULL and the fixed effects
  # were silently dropped from the re-estimated models.
  fixef_vars <- model[["fixef_vars"]]

  # Which is omitted category?
  base_0 <- dummy_vars[!(dummy_vars %in% fml_to_vars(fml))]

  if (length(base_0) == 0) stop("Need to also include omitted category in vars argument")
  if (length(base_0) > 1) {
    stop(str_glue(
      "Expected exactly one omitted category in vars argument, found: {paste(base_0, collapse = ', ')}"
    ))
  }

  # Which are other potential categories
  base_others <- dummy_vars[dummy_vars != base_0]

  # Write a regex that finds the string in the formula with these vars
  regex_others <- base_others %>% paste0(collapse = "\\s+\\+\\s+")

  out <- tibble()

  for (i in seq_along(base_others)) {

    # new_base becomes the omitted category; base_0 is included instead
    new_base <- base_others[[i]]
    new_cats <- dummy_vars[dummy_vars != new_base]
    new_cats_str <- new_cats %>% paste0(collapse = " + ")

    fml_str <- fml %>% fml_to_char()

    if (!str_detect(fml_str, regex_others)) {
      stop(str_glue("Looking for {regex_others} in {fml_str}, no match found, might need to reorder variables in vars"))
    }

    fml_str_new <- str_replace(fml_str, regex_others, new_cats_str)

    fml_new <- fml_str_new %>% as.formula()

    # Print the formula string (glue on the formula object itself printed
    # its three parts as separate lines)
    print(str_glue("fml_new: {fml_str_new}"))

    if (ri == TRUE) {
      # RI is TRUE

      # Generate new control group specification
      control_group_original <- control_group

      # New control group spec - swap the control and the new base
      control_group_new <- control_group_original %>%
        map(~ str_replace(., str_glue("^{new_base}$"), base_0)) %>%
        set_names(names(control_group_original) %>% str_replace(str_glue("^{base_0}$"), new_base)) %>%
        print()

      model_for_p <- feols_custom(
        fml = fml_new,
        data = model$data,
        fixef = fixef_vars,
        cluster = model$cluster,
        ri = TRUE,
        control_group = control_group_new,
        ...
      ) %>%
        .$ri_out %>%
        select(term, p.value = p) %>%
        mutate(base = new_base) %>%
        relocate(base) %>%
        print()
    } else {
      # RI = FALSE
      model_for_p <- feols_custom(
        fml = fml_new,
        data = model$data,
        fixef = fixef_vars,
        cluster = model$cluster
      ) %>%
        tidy_90() %>%
        filter(str_detect(term, vars_to_regex(dummy_vars, keep_vars))) %>%
        select(term, p.value) %>%
        mutate(base = new_base) %>%
        relocate(base)
    }

    out <- bind_rows(out, model_for_p)

  }

  return(out)
}

# Keep the coefficient rows for `vars`, preferring their interactions with
# `interact`.
#
# If any term matches "<interact>:<one of vars>" (regex), keep those rows
# (logged via tidylog::filter). Otherwise keep rows whose term is exactly one
# of `vars`. In both cases "<interact>:" is then removed from the term names.
# Prints whether interaction terms were found.
#
# df:       coefficient table with a `term` column.
# vars:     variable names, used as regex alternatives.
# interact: name of the interacting variable, used inside a regex.
filter_interact_terms <- function(df, vars, interact) {
  var_alternatives <- paste0("(", vars_to_regex(vars), ")")
  interact_pattern <- str_glue("{interact}\\:{var_alternatives}")

  has_interactions <- any(str_detect(pull(df, term), interact_pattern))
  print(has_interactions)

  if (!has_interactions) {
    exact_pattern <- paste0("(", vars_to_regex(paste0("^", vars, "$")), ")")
    kept <- filter(df, str_detect(term, exact_pattern))
  } else {
    kept <- tidylog::filter(df, str_detect(term, interact_pattern))
  }

  mutate(kept, term = str_replace_all(term, str_glue("{interact}\\:"), ""))
}


# Keep the p-values for the phase-2 comparisons we care about

# Convert a list of p-value tables (one per model, each with columns base,
# term and p.value) into one data frame for modelsummary's `add_rows`. The
# first column holds labels "p(<base>=<term>)", followed by one p-value
# column per model.
#
# The labels come from the first table only, so every table is assumed to
# list the same comparisons in the same order.
p_vals_to_add_rows <- function(p_vals) {
  labelled <- map(p_vals, function(tbl) {
    tbl <- mutate(tbl, label = str_glue("p({base}={term})"))
    select(tbl, label, p.value)
  })

  labels <- labelled[[1]][[1]]
  p_value_columns <- map(labelled, "p.value")

  bind_cols(labels, p_value_columns)
}



# Leave-one-out means: element i is mean(x[-i], na.rm = TRUE), i.e. the mean
# of all other elements ignoring NAs (NaN when none remain).
mean_exclude <- function(x) {
  vapply(
    seq_along(x),
    function(i) mean(x[-i], na.rm = TRUE),
    numeric(1)
  )
}

# Convert a named list of equal-length vectors into a tibble for
# modelsummary's `add_rows`. There is one row per list element, with a
# leading `name` column holding the element names. The values go in the
# columns as_tibble() creates for an unnamed matrix.
list_to_add_rows <- function(l) {
  n_rows <- length(l)
  n_cols <- length(l[[1]])

  # Start from an all-NA matrix; filling the rows coerces it to the values' type
  mat <- matrix(NA, nrow = n_rows, ncol = n_cols)
  for (row in seq_len(n_rows)) {
    mat[row, ] <- l[[row]]
  }

  out <- mutate(as_tibble(mat), name = names(l))
  relocate(out, name)
}

# Stack several add_rows tables (e.g. from list_to_add_rows()) into a single
# table for modelsummary's `add_rows`.
#
# Numeric columns are formatted to 3 decimals (NA -> " "), and every column
# is converted to character. Columns are then renamed positionally to
# V1, V2, ... so that bind_rows() aligns the tables by position rather than
# by their original names.
combine_add_rows <- function(...) {
  list(...) %>%
    purrr::map(
      ~ .x %>%
        mutate(across(where(is.numeric), ~ ifelse(is.na(.x), " ", format(round(.x, 3), nsmall = 3)))) %>%
        mutate(across(everything(), as.character)) %>%
        # One name per column. The previous paste0("V", ncol(.)) produced a
        # single name for all columns, which set_names() cannot apply.
        set_names(paste0("V", seq_len(ncol(.))))
    ) %>%
    bind_rows()
}


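# OPEN QUESTION: how should randomization inference (RI) work for a multi-arm
# trial? We only want to re-randomize between the treatment arm and control,
# leaving the other arms unchanged. One option: filter the dataset to those
# two arms, run the RI, and then add the remaining observations back in.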

# p-value of coefficient `term` in `model`, taken from tidy_custom.fixest().
# If the tidy output has an `ri_p` column, the RI p-value is preferred.
get_p_val <- function(model, term) {
  tidied <- tidy_custom.fixest(model)

  has_ri <- "ri_p" %in% names(tidied)
  if (has_ri) {
    tidied <- mutate(tidied, p.value = coalesce(ri_p, p.value))
  }

  matching_rows <- filter(tidied, term == !!term)
  pull(matching_rows, p.value)
}

# Mean (via the mean_na() helper) of the column of `df` selected by `term`,
# which is passed with {{ }} to dplyr::pull().
get_mean <- function(df, term) {
  column_values <- pull(df, {{ term }})
  mean_na(column_values)
}


# Point estimate of coefficient `term`, read from model$coefficients.
get_coeff <- function(model, term) {
  coefs <- model$coefficients
  coefs[[term]]
}


# t-test for the difference between two coefficients of the same model.
#
# model:        fitted model used with get_coeff(), vcov() and
#               degrees_freedom(model, "t").
# var_1, var_2: coefficient names.
#
# Returns a list:
#   diff    = coef[var_2] - coef[var_1]
#   se_diff = sqrt(Var[var_1] + Var[var_2] - 2 * Cov[var_1, var_2])
#   t_stat  = diff / se_diff
#   p_val   = two-sided p-value from a t distribution with
#             degrees_freedom(model, "t") degrees of freedom.
get_diff <- function(model, var_1, var_2) {

  coeff_1 <- get_coeff(model, var_1)
  coeff_2 <- get_coeff(model, var_2)
  # Named `estimate` so it does not shadow base::diff()
  estimate <- coeff_2 - coeff_1

  vcov_model <- vcov(model)

  cov_12 <- vcov_model[var_1, var_2]
  v1 <- vcov_model[var_1, var_1]
  v2 <- vcov_model[var_2, var_2]

  se_diff <- sqrt(v1 + v2 - 2 * cov_12)

  t_stat <- estimate / se_diff

  deg_free <- degrees_freedom(model, "t")
  # Upper tail computed directly: 1 - pt(...) rounds to exactly 0 for large |t|
  p_val <- 2 * pt(abs(t_stat), deg_free, lower.tail = FALSE)

  list(diff = estimate, se_diff = se_diff,
       t_stat = t_stat, p_val = p_val)
}


# One-row tibble (base = var_1, term = var_2, p.value) holding the p-value of
# get_diff(model, var_1, var_2). This is the same column layout that
# get_comparison_p_vals() returns.
get_diff_p_vals <- function(model, var_1, var_2) {
  comparison <- get_diff(model, var_1, var_2)
  tibble(base = var_1, term = var_2, p.value = comparison$p_val)
}

# `n` evenly spaced HCL colours (luminance 65, chroma 100), with hues starting
# at 15 degrees. This presumably mimics ggplot2's default discrete palette.
#
# Returns a character vector of n colours; character(0) when n is 0.
gg_color_hue <- function(n) {
  # n + 1 points from 15 to 375 (= 15 + 360): the last wraps round to the
  # first, so only the first n are kept. seq_len(n) (not 1:n) keeps n = 0
  # from returning one colour.
  hues <- seq(15, 375, length.out = n + 1)
  hcl(h = hues, l = 65, c = 100)[seq_len(n)]
}



# Multiply `x` by 100 (e.g. shares to percentages).
times_100 <- function(x) 100 * x